{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "24326bfc-3ec0-4efd-9c9b-cd8bb00608a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "%run -i \"../util/lang_utils.ipynb\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "845abbda-1c9b-4c61-8250-8aa4be058632",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset, Dataset, Features, Value, ClassLabel, Sequence, DatasetDict\n",
    "import pandas as pd\n",
    "from transformers import AutoTokenizer, AutoModel\n",
    "from transformers import DataCollatorForTokenClassification\n",
    "from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer\n",
    "import numpy as np\n",
    "from sklearn.model_selection import train_test_split\n",
    "from evaluate import load"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a0fd4c34-111e-4174-9aaf-050676b43657",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "        id                                               text  start_offset  \\\n",
      "0    13434  i love radioheads kid a something similar , ki...             7   \n",
      "1    13434  i love radioheads kid a something similar , ki...            61   \n",
      "2    13435                anything similar to i fight dragons            20   \n",
      "3    13436                music similar to ccrs travelin band            17   \n",
      "4    13437                 songs similar to blackout by boris            17   \n",
      "..     ...                                                ...           ...   \n",
      "422  14028  songs like good news by mac miller , preferrab...            11   \n",
      "423  14028  songs like good news by mac miller , preferrab...            24   \n",
      "424  14030  something along the lines of either the chain ...            49   \n",
      "425  14030  something along the lines of either the chain ...            29   \n",
      "426  14032       heavy bass x gothic rap like oxygen by bones            29   \n",
      "\n",
      "     end_offset          label  \n",
      "0            17   Artist_known  \n",
      "1            71  Artist_or_WoA  \n",
      "2            35            WoA  \n",
      "3            30         Artist  \n",
      "4            25            WoA  \n",
      "..          ...            ...  \n",
      "422          20            WoA  \n",
      "423          34         Artist  \n",
      "424          60  Artist_or_WoA  \n",
      "425          45  Artist_or_WoA  \n",
      "426          44  Artist_or_WoA  \n",
      "\n",
      "[427 rows x 5 columns]\n"
     ]
    }
   ],
   "source": [
    "music_ner_df = pd.read_csv('../data/music_ner.csv')\n",
    "def change_label(input_label):\n",
    "    input_label = input_label.replace(\"_deduced\", \"\")\n",
    "    return input_label\n",
    "music_ner_df[\"label\"] = music_ner_df[\"label\"].apply(change_label)\n",
    "music_ner_df[\"text\"] = music_ner_df[\"text\"].apply(lambda x: x.replace(\"|\", \",\"))\n",
    "print(music_ner_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "522624bd-8ba1-4c02-a282-995c7e34d5f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "ids = list(set(music_ner_df[\"id\"].values))\n",
    "docs = {}\n",
    "for id in ids:\n",
    "    entity_rows = music_ner_df.loc[music_ner_df['id'] == id]\n",
    "    text = entity_rows.head(1)[\"text\"].values[0]\n",
    "    doc = small_model(text)\n",
    "    ents = []\n",
    "    for index, row in entity_rows.iterrows():\n",
    "        label = row[\"label\"]\n",
    "        start = row[\"start_offset\"]\n",
    "        end = row[\"end_offset\"]\n",
    "        span = doc.char_span(start, end, label=label, alignment_mode=\"contract\")\n",
    "        ents.append(span)\n",
    "    doc.ents = ents\n",
    "    docs[doc.text] = doc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "8a130f12-e3f6-472d-81bb-477fd67d2f65",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_file = \"../data/music_ner_bio.bio\"\n",
    "tag_mapping = {\"O\": 0, \"B-Artist\": 1, \"I-Artist\": 2, \"B-WoA\": 3, \"I-WoA\": 4}\n",
    "with open(data_file) as f:\n",
    "    data = f.read()\n",
    "tokens = []\n",
    "ner_tags = []\n",
    "spans = []\n",
    "sentences = data.split(\"\\n\\n\")\n",
    "for sentence in sentences:\n",
    "    words = []\n",
    "    tags = []\n",
    "    this_sentence_spans = []\n",
    "    word_tag_pairs = sentence.split(\"\\n\")\n",
    "    for pair in word_tag_pairs:\n",
    "        (word, tag) = pair.split(\"\\t\")\n",
    "        words.append(word)\n",
    "        tags.append(tag_mapping[tag])\n",
    "    sentence_text = \" \".join(words)\n",
    "    try:\n",
    "        doc = docs[sentence_text]\n",
    "    except:\n",
    "        pass\n",
    "    ent_dict = {}\n",
    "    for ent in doc.ents:\n",
    "        this_sentence_spans.append(f\"{ent.label_}: {ent.text}\")\n",
    "    tokens.append(words)\n",
    "    ner_tags.append(tags)\n",
    "    spans.append(this_sentence_spans)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "81b9adb8-42d9-4e1b-a536-cfbbf52a1931",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "539\n",
      "60\n"
     ]
    }
   ],
   "source": [
    "indices = range(0, len(spans))\n",
    "train, test = train_test_split(indices, test_size=0.1)\n",
    "train_tokens = []\n",
    "test_tokens = []\n",
    "train_ner_tags = []\n",
    "test_ner_tags = []\n",
    "train_spans = []\n",
    "test_spans = []\n",
    "for i, (token, ner_tag, span) in enumerate(zip(tokens, ner_tags, spans)):\n",
    "    if i in train:\n",
    "        train_tokens.append(token)\n",
    "        train_ner_tags.append(ner_tag)\n",
    "        train_spans.append(span)\n",
    "    else:\n",
    "        test_tokens.append(token)\n",
    "        test_ner_tags.append(ner_tag)\n",
    "        test_spans.append(span)        \n",
    "        \n",
    "print(len(train_spans))\n",
    "print(len(test_spans))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "afaf9488-01a0-45ee-967c-6f283d717bdf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                                               tokens  \\\n",
      "0   [rap, hip, hop, that, sounds, like, crushed, u...   \n",
      "1   [andrew, huang, 300000, note, song, ,, any, ot...   \n",
      "2   [looking, for, more, sad, songs, in, the, same...   \n",
      "3                    [scott, pilgrim, vs, the, world]   \n",
      "4          [a, beautiful, suggestion, would, suffice]   \n",
      "5   [looking, for, a, very, specific, brand, of, f...   \n",
      "6   [looking, for, more, music, like, this, i, gue...   \n",
      "7                                  [something, weird]   \n",
      "8                              [hard, rock, suggests]   \n",
      "9   [songs, similar, to, just, my, imagination, by...   \n",
      "10  [something, like, m2us, magnolia, without, the...   \n",
      "11                              [a, song, like, this]   \n",
      "12  [looking, for, more, songs, similar, in, conce...   \n",
      "13  [similar, to, in, this, moment, or, pretty, re...   \n",
      "14          [clean, weight, training, playlist, help]   \n",
      "15  [know, any, glitchy, songs, like, virtual, sel...   \n",
      "16                          [i, need, a, 100th, song]   \n",
      "17                          [dark, covers, of, songs]   \n",
      "18  [playlist, dedicated, to, songs, similar, in, ...   \n",
      "19  [searching, for, songs, like, futures, mask, off]   \n",
      "20  [looking, for, pop, rock, power, songs, like, ...   \n",
      "21                                [chill, psych, pop]   \n",
      "22  [a, song, to, play, on, the, piano, and, sing,...   \n",
      "23  [similar, to, chris, stapleton, or, first, aid...   \n",
      "24  [looking, for, something, like, dragons, dogma...   \n",
      "25                    [songs, like, amour, plastique]   \n",
      "26                                      [need, music]   \n",
      "27  [looking, for, songs, similar, to, this, psych...   \n",
      "28                  [chill, piano, electronic, music]   \n",
      "29                               [rainy, days, songs]   \n",
      "30  [so, long, see, you, tomorrow, bombay, bicycle...   \n",
      "31  [contemporary, avant, garde, electronica, expe...   \n",
      "32  [song, similar, in, feeling, and, vibe, than, ...   \n",
      "33  [looking, for, music, similar, to, bishops, kn...   \n",
      "34  [looking, for, music, with, a, rhytmic, and, d...   \n",
      "35  [these, songs, are, not, that, well, known, bu...   \n",
      "36  [music, artists, similar, to, nujabes, atlas, ...   \n",
      "37         [irish, hip, hop, cloud, 404, 9life, 2020]   \n",
      "38     [creating, a, grade, 6, graduation, slideshow]   \n",
      "39  [suggestions, for, street, level, combat, them...   \n",
      "40                [rock, but, with, a, deep, message]   \n",
      "41  [im, working, on, a, trip, hop, downtempo, chi...   \n",
      "42        [song, similar, to, airport, bar, by, noah]   \n",
      "43                   [shamanic, experience, playlist]   \n",
      "44  [artists, like, smino, sudan, archives, fjk, j...   \n",
      "45  [looking, for, something, like, the, punk, ref...   \n",
      "46  [looking, for, songs, with, young, female, pop...   \n",
      "47  [trying, to, find, an, album, like, solitude, ...   \n",
      "48  [songs, similar, to, cindy, by, tammany, hall,...   \n",
      "49  [in, need, of, some, music, similar, to, the, ...   \n",
      "50  [after, a, long, music, rut, i, was, happy, to...   \n",
      "51  [im, searching, for, songs, that, are, like, p...   \n",
      "52  [looking, for, songs, that, follow, a, fun, cr...   \n",
      "53                   [best, rap, with, piano, in, it]   \n",
      "54  [looking, for, songs, like, the, nights, ,, av...   \n",
      "55  [indie, music, choir, harmonizing, with, lead,...   \n",
      "56            [songs, you, fit, my, summer, playlist]   \n",
      "57  [any, songs, you, know, of, that, have, a, sim...   \n",
      "58              [similar, music, as, sopor, aeternus]   \n",
      "59  [heavy, bass, x, gothic, rap, like, oxygen, by...   \n",
      "\n",
      "                                             ner_tags  \\\n",
      "0                      [0, 0, 0, 0, 0, 0, 3, 4, 0, 1]   \n",
      "1                [1, 2, 3, 4, 4, 0, 0, 0, 0, 0, 0, 0]   \n",
      "2   [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 1, ...   \n",
      "3                                     [3, 4, 4, 4, 4]   \n",
      "4                                     [0, 0, 0, 0, 0]   \n",
      "5                         [0, 0, 0, 0, 0, 0, 0, 0, 0]   \n",
      "6             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]   \n",
      "7                                              [0, 0]   \n",
      "8                                           [0, 0, 0]   \n",
      "9                         [0, 0, 0, 3, 4, 4, 0, 1, 2]   \n",
      "10                              [0, 0, 1, 3, 0, 0, 0]   \n",
      "11                                       [0, 0, 0, 0]   \n",
      "12               [0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 1, 2]   \n",
      "13                           [0, 0, 1, 2, 2, 0, 1, 2]   \n",
      "14                                    [0, 0, 0, 0, 0]   \n",
      "15                        [0, 0, 0, 0, 0, 1, 2, 3, 4]   \n",
      "16                                    [0, 0, 0, 0, 0]   \n",
      "17                                       [0, 0, 0, 0]   \n",
      "18  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, ...   \n",
      "19                              [0, 0, 0, 0, 1, 3, 4]   \n",
      "20                  [0, 0, 0, 0, 0, 0, 0, 0, 3, 4, 4]   \n",
      "21                                          [0, 0, 0]   \n",
      "22   [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]   \n",
      "23                           [0, 0, 1, 2, 0, 1, 2, 2]   \n",
      "24                        [0, 0, 0, 0, 3, 4, 3, 4, 4]   \n",
      "25                                       [0, 0, 3, 4]   \n",
      "26                                             [0, 0]   \n",
      "27         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]   \n",
      "28                                       [0, 0, 0, 0]   \n",
      "29                                          [0, 0, 0]   \n",
      "30         [3, 4, 4, 4, 4, 1, 2, 2, 0, 0, 0, 0, 0, 0]   \n",
      "31                              [0, 0, 0, 0, 0, 0, 0]   \n",
      "32                        [0, 0, 0, 0, 0, 0, 0, 1, 2]   \n",
      "33      [0, 0, 0, 0, 0, 3, 4, 4, 1, 2, 2, 0, 3, 4, 1]   \n",
      "34               [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]   \n",
      "35  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...   \n",
      "36                              [0, 0, 0, 0, 1, 1, 0]   \n",
      "37                              [0, 0, 0, 1, 2, 3, 0]   \n",
      "38                                 [0, 0, 0, 0, 0, 0]   \n",
      "39                     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]   \n",
      "40                                 [0, 0, 0, 0, 0, 0]   \n",
      "41  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...   \n",
      "42                              [0, 0, 0, 3, 4, 0, 1]   \n",
      "43                                          [0, 0, 0]   \n",
      "44                  [0, 0, 1, 1, 2, 1, 1, 2, 0, 1, 2]   \n",
      "45                     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]   \n",
      "46                           [0, 0, 0, 0, 0, 0, 0, 0]   \n",
      "47  [0, 0, 0, 0, 0, 0, 3, 4, 0, 1, 2, 0, 0, 0, 0, ...   \n",
      "48                           [0, 0, 0, 3, 0, 1, 2, 2]   \n",
      "49  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 2, 2, 1, ...   \n",
      "50  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, ...   \n",
      "51                     [0, 0, 0, 0, 0, 0, 0, 1, 2, 3]   \n",
      "52                        [0, 0, 0, 0, 0, 0, 0, 0, 0]   \n",
      "53                                 [0, 0, 0, 0, 0, 0]   \n",
      "54                           [0, 0, 0, 0, 3, 4, 0, 1]   \n",
      "55                              [0, 0, 0, 0, 0, 0, 0]   \n",
      "56                                 [0, 0, 0, 0, 0, 0]   \n",
      "57                     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]   \n",
      "58                                    [0, 0, 0, 1, 2]   \n",
      "59                        [0, 0, 0, 0, 0, 0, 3, 0, 1]   \n",
      "\n",
      "                                                spans  \\\n",
      "0                   [WoA: crushed up, Artist: future]   \n",
      "1       [Artist: andrew huang, WoA: 300000 note song]   \n",
      "2   [WoA: funeral, Artist: phoebe bridges, WoA: tr...   \n",
      "3         [Artist_or_WoA: scott pilgrim vs the world]   \n",
      "4                         [WoA: gnaw, Artist: alex g]   \n",
      "5                         [WoA: gnaw, Artist: alex g]   \n",
      "6                                  [Artist: lindécis]   \n",
      "7                                  [Artist: lindécis]   \n",
      "8      [WoA: little lion man, Artist: mumford & sons]   \n",
      "9   [WoA: just my imagination, Artist: the temptat...   \n",
      "10                      [Artist: m2us, WoA: magnolia]   \n",
      "11           [Artist: perry comos, WoA: anema e core]   \n",
      "12        [WoA: yesterday, Artist_known: the beatles]   \n",
      "13        [WoA: in this moment, WoA: pretty reckless]   \n",
      "14              [WoA: say it aint so, Artist: weezer]   \n",
      "15        [Artist_or_WoA: virtual self particle ants]   \n",
      "16  [WoA: devil devil, Artist: milck, WoA: prom qu...   \n",
      "17  [Artist_known: taylor swift, WoA: reputation, ...   \n",
      "18                               [Artist: the smiths]   \n",
      "19                   [Artist: futures, WoA: mask off]   \n",
      "20                                 [WoA: let em talk]   \n",
      "21                    [Artist_or_WoA: the mini ladds]   \n",
      "22                    [Artist_or_WoA: the mini ladds]   \n",
      "23  [Artist: chris stapleton, Artist_or_WoA: first...   \n",
      "24     [Artist: dragons, WoA: dogma into free dangan]   \n",
      "25                             [WoA: amour plastique]   \n",
      "26                           [Artist_or_WoA: crywank]   \n",
      "27                                        [WoA: coco]   \n",
      "28                                        [WoA: coco]   \n",
      "29                            [Artist: novos baianos]   \n",
      "30  [Artist_or_WoA: so long see you tomorrow bomba...   \n",
      "31  [Artist_or_WoA: so long see you tomorrow bomba...   \n",
      "32                               [WoA: grateful dead]   \n",
      "33  [Artist_or_WoA: bishops knife trick, Artist_or...   \n",
      "34                 [WoA: bowsprit, Artist: balmorhea]   \n",
      "35  [Artist_or_WoA: colde sunflower colde shh offo...   \n",
      "36                     [Artist_or_WoA: nujabes atlas]   \n",
      "37          [Artist_or_WoA: gsh, Artist_or_WoA: gaye]   \n",
      "38            [Artist: bon iver, Artist: iron & wine]   \n",
      "39            [Artist: bon iver, Artist: iron & wine]   \n",
      "40                                 [WoA: ach so gern]   \n",
      "41   [Artist: kid rocks, WoA: greatest show on earth]   \n",
      "42             [WoA: airport bar, Artist_known: noah]   \n",
      "43                                    [Artist: sweet]   \n",
      "44  [Artist: smino sudan, Artist: archives, Artist...   \n",
      "45              [WoA: an awesome wave, Artist: alt j]   \n",
      "46               [WoA_known: guardians of the galaxy]   \n",
      "47     [WoA: solitude standing, Artist: suzanne vega]   \n",
      "48             [WoA: cindy, Artist: tammany hall nyc]   \n",
      "49  [WoA: the likes, Artist_known: rage against th...   \n",
      "50                                [Artist: koi child]   \n",
      "51            [Artist_known: post malone, WoA: leave]   \n",
      "52  [Artist: princess and the frog, Artist: hadest...   \n",
      "53                [WoA: time to move, Artist: carmen]   \n",
      "54            [WoA: the nights, Artist_known: avicii]   \n",
      "55    [WoA: hold my liquor, Artist_known: kanye west]   \n",
      "56  [WoA: the piper at the gates of dawn, Artist_k...   \n",
      "57  [WoA: the piper at the gates of dawn, Artist_k...   \n",
      "58                    [Artist_or_WoA: sopor aeternus]   \n",
      "59                   [Artist_or_WoA: oxygen by bones]   \n",
      "\n",
      "                                                 text  \n",
      "0   rap hip hop that sounds like crushed up by future  \n",
      "1   andrew huang 300000 note song , any other song...  \n",
      "2   looking for more sad songs in the same vein as...  \n",
      "3                          scott pilgrim vs the world  \n",
      "4                a beautiful suggestion would suffice  \n",
      "5     looking for a very specific brand of folk music  \n",
      "6   looking for more music like this i guess its l...  \n",
      "7                                     something weird  \n",
      "8                                  hard rock suggests  \n",
      "9   songs similar to just my imagination by the te...  \n",
      "10   something like m2us magnolia without the dubstep  \n",
      "11                                   a song like this  \n",
      "12  looking for more songs similar in concept to y...  \n",
      "13       similar to in this moment or pretty reckless  \n",
      "14                clean weight training playlist help  \n",
      "15  know any glitchy songs like virtual self parti...  \n",
      "16                                i need a 100th song  \n",
      "17                               dark covers of songs  \n",
      "18  playlist dedicated to songs similar in feel an...  \n",
      "19          searching for songs like futures mask off  \n",
      "20  looking for pop rock power songs like this let...  \n",
      "21                                    chill psych pop  \n",
      "22  a song to play on the piano and sing along to ...  \n",
      "23        similar to chris stapleton or first aid kit  \n",
      "24  looking for something like dragons dogma into ...  \n",
      "25                         songs like amour plastique  \n",
      "26                                         need music  \n",
      "27  looking for songs similar to this psychedelic ...  \n",
      "28                       chill piano electronic music  \n",
      "29                                   rainy days songs  \n",
      "30  so long see you tomorrow bombay bicycle club a...  \n",
      "31  contemporary avant garde electronica experimen...  \n",
      "32  song similar in feeling and vibe than grateful...  \n",
      "33  looking for music similar to bishops knife tri...  \n",
      "34  looking for music with a rhytmic and driving b...  \n",
      "35  these songs are not that well known but what a...  \n",
      "36         music artists similar to nujabes atlas etc  \n",
      "37                 irish hip hop cloud 404 9life 2020  \n",
      "38            creating a grade 6 graduation slideshow  \n",
      "39  suggestions for street level combat themes , a...  \n",
      "40                       rock but with a deep message  \n",
      "41  im working on a trip hop downtempo chillout lo...  \n",
      "42                song similar to airport bar by noah  \n",
      "43                       shamanic experience playlist  \n",
      "44  artists like smino sudan archives fjk jessie r...  \n",
      "45  looking for something like the punk refrain in...  \n",
      "46  looking for songs with young female popgeneric...  \n",
      "47  trying to find an album like solitude standing...  \n",
      "48         songs similar to cindy by tammany hall nyc  \n",
      "49  in need of some music similar to the likes of ...  \n",
      "50  after a long music rut i was happy to find koi...  \n",
      "51  im searching for songs that are like post malo...  \n",
      "52    looking for songs that follow a fun creepy vibe  \n",
      "53                          best rap with piano in it  \n",
      "54         looking for songs like the nights , avicii  \n",
      "55     indie music choir harmonizing with lead singer  \n",
      "56                   songs you fit my summer playlist  \n",
      "57    any songs you know of that have a similar style  \n",
      "58                    similar music as sopor aeternus  \n",
      "59       heavy bass x gothic rap like oxygen by bones  \n"
     ]
    }
   ],
   "source": [
    "training_df = pd.DataFrame({\"tokens\":train_tokens, \"ner_tags\": train_ner_tags, \"spans\": train_spans})\n",
    "test_df = pd.DataFrame({\"tokens\": test_tokens, \"ner_tags\": test_ner_tags, \"spans\": test_spans})\n",
    "training_df[\"text\"] = training_df[\"tokens\"].apply(lambda x: \" \".join(x))\n",
    "test_df[\"text\"] = test_df[\"tokens\"].apply(lambda x: \" \".join(x))\n",
    "training_df.dropna()\n",
    "test_df.dropna()\n",
    "print(test_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "7842b146-d4f2-47d8-9430-ce8bb3b97bf4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bb537050d0cf4305bd89052f0065613c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'tokens': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None), 'ner_tags': Sequence(feature=ClassLabel(names=['O', 'B-Artist', 'I-Artist', 'B-WoA', 'I-WoA'], id=None), length=-1, id=None), 'spans': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None), 'text': Value(dtype='string', id=None)}\n",
      "DatasetDict({\n",
      "    train: Dataset({\n",
      "        features: ['tokens', 'ner_tags', 'spans', 'text'],\n",
      "        num_rows: 539\n",
      "    })\n",
      "    test: Dataset({\n",
      "        features: ['tokens', 'ner_tags', 'spans', 'text'],\n",
      "        num_rows: 60\n",
      "    })\n",
      "})\n"
     ]
    }
   ],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')\n",
    "#model = AutoModel.from_pretrained(\"bert-base-cased\")\n",
    "features = Features({'tokens': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None), \n",
    "            'ner_tags': Sequence(feature=ClassLabel(names=['O', 'B-Artist', 'I-Artist', 'B-WoA', 'I-WoA'], id=None), length=-1, id=None), \n",
    "            'spans': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None),\n",
    "            'text': Value(dtype='string', id=None)\n",
    "                    })\n",
    "training_dataset = Dataset.from_pandas(training_df, features=features)\n",
    "test_dataset = Dataset.from_pandas(test_df, features=features)\n",
    "dataset = DatasetDict({\"train\":training_dataset, \"test\":test_dataset}) \n",
    "print(dataset[\"train\"].features)\n",
    "label_names = dataset[\"train\"].features[\"ner_tags\"].feature.names\n",
    "print(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "47e6a784-e067-4fa7-a3e5-d300c423df39",
   "metadata": {},
   "outputs": [],
   "source": [
    "def tokenize_adjust_labels(all_samples_per_split):\n",
    "    tokenized_samples = tokenizer.batch_encode_plus(all_samples_per_split[\"text\"])\n",
    "    total_adjusted_labels = []\n",
    "    for k in range(0, len(tokenized_samples[\"input_ids\"])):\n",
    "        prev_wid = -1\n",
    "        word_ids_list = tokenized_samples.word_ids(batch_index=k)\n",
    "        existing_label_ids = all_samples_per_split[\"ner_tags\"][k]\n",
    "        i = -1\n",
    "        adjusted_label_ids = []\n",
    "        for wid in word_ids_list:\n",
    "            if (wid is None):\n",
    "                adjusted_label_ids.append(-100)\n",
    "            elif (wid != prev_wid):\n",
    "                i = i + 1\n",
    "                adjusted_label_ids.append(existing_label_ids[i])\n",
    "                prev_wid = wid\n",
    "            else:\n",
    "                label_name = label_names[existing_label_ids[i]]\n",
    "                adjusted_label_ids.append(existing_label_ids[i])\n",
    "        total_adjusted_labels.append(adjusted_label_ids)\n",
    "    tokenized_samples[\"labels\"] = total_adjusted_labels\n",
    "    return tokenized_samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "6e914182-53c0-4b38-942f-366f0e0af5a4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "aadc34d4d4734b3293c334392ddd2c79",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/539 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5d99799c48f047b6bf893b53d3347132",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/60 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "tokenized_dataset = dataset.map(tokenize_adjust_labels, batched=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "6f80341e-71c3-4e16-a17c-ae9077443b4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_collator = DataCollatorForTokenClassification(tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "9462b75d-316b-4e8b-ad68-2e83d235b946",
   "metadata": {},
   "outputs": [],
   "source": [
    "metric = load(\"seqeval\")\n",
    "def compute_metrics(data):\n",
    "    predictions, labels = data\n",
    "    predictions = np.argmax(predictions, axis=2)\n",
    "\n",
    "    data = zip(predictions, labels)\n",
    "    data = [[(p, l) for (p, l) in zip(prediction, label) if l != -100] for prediction, label in data]\n",
    "\n",
    "    true_predictions = [[label_names[p] for (p, l) in data_point] for data_point in data]\n",
    "    true_labels = [[label_names[l] for (p, l) in data_point] for data_point in data]\n",
    "    \n",
    "    results = metric.compute(predictions=true_predictions, references=true_labels)\n",
    "    flat_results = {\n",
    "        \"overall_precision\": results[\"overall_precision\"],\n",
    "        \"overall_recall\": results[\"overall_recall\"],\n",
    "        \"overall_f1\": results[\"overall_f1\"],\n",
    "        \"overall_accuracy\": results[\"overall_accuracy\"],\n",
    "    }\n",
    "    for k in results.keys():\n",
    "      if (k not in flat_results.keys()):\n",
    "        flat_results[k + \"_f1\"] = results[k][\"f1\"]\n",
    "\n",
    "    return flat_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "62a1e3af-ba38-429f-86ac-6b42d0a6ceaf",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='238' max='238' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [238/238 00:25, Epoch 7/7]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=238, training_loss=0.25769581514246326, metrics={'train_runtime': 25.8951, 'train_samples_per_second': 145.703, 'train_steps_per_second': 9.191, 'total_flos': 49438483110900.0, 'train_loss': 0.25769581514246326, 'epoch': 7.0})"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Train model\n",
    "model = AutoModelForTokenClassification.from_pretrained('bert-base-uncased', num_labels=len(label_names))\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=\"./fine_tune_bert_output\",\n",
    "    evaluation_strategy=\"steps\",\n",
    "    learning_rate=2e-5,\n",
    "    per_device_train_batch_size=16,\n",
    "    per_device_eval_batch_size=16,\n",
    "    num_train_epochs=7,\n",
    "    weight_decay=0.01,\n",
    "    logging_steps = 1000,\n",
    "    run_name = \"ep_10_tokenized_11\",\n",
    "    save_strategy='no'\n",
    ")\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=tokenized_dataset[\"train\"],\n",
    "    eval_dataset=tokenized_dataset[\"test\"],\n",
    "    data_collator=data_collator,\n",
    "    tokenizer=tokenizer,\n",
    "    compute_metrics=compute_metrics\n",
    ")\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "aaaafe40-8de7-45ea-8ac9-1aae4310531f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='4' max='4' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [4/4 00:00]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "{'eval_loss': 0.28670933842658997,\n",
       " 'eval_overall_precision': 0.6470588235294118,\n",
       " 'eval_overall_recall': 0.7096774193548387,\n",
       " 'eval_overall_f1': 0.6769230769230768,\n",
       " 'eval_overall_accuracy': 0.9153605015673981,\n",
       " 'eval_Artist_f1': 0.761904761904762,\n",
       " 'eval_WoA_f1': 0.5217391304347826,\n",
       " 'eval_runtime': 0.3239,\n",
       " 'eval_samples_per_second': 185.262,\n",
       " 'eval_steps_per_second': 12.351,\n",
       " 'epoch': 7.0}"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Evaluate model\n",
    "trainer.evaluate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "70dee8f3-c3e0-400b-a44a-73a2e73ae60e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save model\n",
    "trainer.save_model(\"../models/bert_fine_tuned\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "350ed6fb-d9fe-4091-8c94-36bf86ccb900",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use model\n",
    "model = AutoModelForTokenClassification.from_pretrained(\"../models/bert_fine_tuned\")\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"../models/bert_fine_tuned\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "4fe4033b-f751-4141-8c29-26f677e7c56c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'entity_group': 'LABEL_0',\n",
       "  'score': 0.9991929,\n",
       "  'word': 'music similar to',\n",
       "  'start': 0,\n",
       "  'end': 16},\n",
       " {'entity_group': 'LABEL_1',\n",
       "  'score': 0.8970744,\n",
       "  'word': 'morphine robocobra',\n",
       "  'start': 17,\n",
       "  'end': 35},\n",
       " {'entity_group': 'LABEL_2',\n",
       "  'score': 0.5060059,\n",
       "  'word': 'quartet',\n",
       "  'start': 36,\n",
       "  'end': 43},\n",
       " {'entity_group': 'LABEL_0',\n",
       "  'score': 0.9988042,\n",
       "  'word': '| featuring elements like saxophone prominent bass',\n",
       "  'start': 44,\n",
       "  'end': 94}]"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "text = \"music similar to morphine robocobra quartet | featuring elements like saxophone prominent bass\"\n",
    "from transformers import pipeline\n",
    "pipe = pipeline(task=\"token-classification\", model=model.to(\"cpu\"), tokenizer=tokenizer, aggregation_strategy=\"simple\")\n",
    "pipe(text)\n",
    "# tag_mapping = {\"O\": 0, \"B-Artist\": 1, \"I-Artist\": 2, \"B-WoA\": 3, \"I-WoA\": 4}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4324185a-c9a7-4d12-96ca-bc309ece74d1",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
